import torch


def get_multi_sincos_pos_embed(embed_dim, grid_size, add_cls_token=False):
    """
    Build fixed sine/cosine positional embeddings for every point of an N-D grid.

    Parameters:
    - embed_dim: int, number of channels per position. Must be divisible by
      2 * len(grid_size), so each grid axis receives an equal sin/cos share.
    - grid_size: list[int], extent of each grid axis, e.g. [8, 12, 12] for a 3D grid.
      At least two axes are required.
    - add_cls_token: bool, if True, an all-zero row is prepended for a class token.

    Returns:
    - torch.Tensor of shape [1, prod(grid_size) + (1 if add_cls_token else 0), embed_dim].
    """
    n_axes = len(grid_size)
    assert n_axes >= 2, "Grid_size should be at least 2D"
    assert embed_dim % (n_axes * 2) == 0, "Each dimension has 2 channels (sin, cos)"

    # With 'ij' indexing, coords[k] holds the k-th integer coordinate of every
    # grid point, so coords has shape (n_axes, *grid_size).
    axis_ranges = [torch.arange(extent, dtype=torch.float32) for extent in grid_size]
    coords = torch.stack(torch.meshgrid(*axis_ranges, indexing='ij'))

    table = get_multi_sincos_pos_embed_from_grid(embed_dim, coords)

    if add_cls_token:
        # The class token carries no positional information: a row of zeros.
        cls_row = torch.zeros(1, embed_dim)
        table = torch.cat((cls_row, table))

    # Add a leading batch axis of size 1.
    return table[None]

def get_multi_sincos_pos_embed_from_grid(embed_dim, grid):
    """
    Encode grid coordinates by concatenating one 1D sin/cos embedding per axis.

    Parameters:
    - embed_dim: int, total output channels per position. Must be divisible by
      2 * grid.shape[0]; each coordinate axis gets embed_dim // grid.shape[0] channels.
    - grid: torch.Tensor of shape (G, ...), where grid[i] holds the i-th coordinate
      of every position. The trailing dimensions are flattened by the 1D encoder,
      so a dense (G, d1, ..., dG) grid and a reshaped layout such as (G, 1, H, W)
      are handled the same way.

    Returns:
    - torch.Tensor of shape (M, embed_dim), with M = grid[0].numel(). Channels are
      grouped by axis: axis 0's block first, then axis 1's, and so on.
    """
    # The number of coordinate axes is the size of the stacking dimension, not
    # grid.ndim - 1: the two differ whenever the trailing layout is not exactly G-D.
    num_axes = grid.shape[0]
    # Fail loudly instead of silently returning fewer than embed_dim channels.
    assert embed_dim % (num_axes * 2) == 0, "Each dimension has 2 channels (sin, cos)"
    per_axis_dim = embed_dim // num_axes

    emb = [get_1d_sincos_pos_embed_from_grid(per_axis_dim, grid[i]) for i in range(num_axes)]
    return torch.cat(emb, dim=1)  # G x (M, D/G) -> (M, D)

def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """
    Sine/cosine embedding of scalar positions.

    embed_dim: output dimension for each position; must be even. The first
        embed_dim // 2 channels are sin(pos * omega_k) and the rest cos(pos * omega_k),
        with omega_k = 1 / 10000 ** (2k / embed_dim) for k = 0 .. embed_dim // 2 - 1.
    pos: positions to encode: a tensor of any shape (flattened to (M,)) or anything
        torch.as_tensor accepts, such as a list of numbers. Integer, bool and
        half-precision positions are computed in float32; float64 stays float64.
    out: torch.Tensor of shape (M, embed_dim), on the same device as pos.
    """
    assert embed_dim % 2 == 0
    pos = torch.as_tensor(pos)
    # Compute in at least float32: integer positions would otherwise clash with
    # the floating-point frequencies, and half precision is too coarse here.
    dtype = torch.promote_types(pos.dtype, torch.float32)
    pos = pos.reshape(-1).to(dtype)  # (M,)

    # Build the frequencies on pos's device so the outer product never mixes devices.
    omega = torch.arange(embed_dim // 2, dtype=dtype, device=pos.device) / (embed_dim / 2.)
    omega = 1. / 10000 ** omega  # (D/2,)

    out = torch.outer(pos, omega)  # (M, D/2)
    return torch.cat([torch.sin(out), torch.cos(out)], dim=1)  # (M, D)
